/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Contains classes for drawing objects from probability distributions.

// $Id$

#ifndef __montecarlo__
#define __montecarlo__

#include "blitz/array.h"

#include "random.h"
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

namespace mcrx {
  // Forward declarations of the sampling classes. The full definitions,
  // including their default template arguments, appear further down in
  // this header, qualified with mcrx::.
  template <typename T_object,
            template <typename> class T_sampling,
            typename T_rng>
  class probability_distribution;

  template <typename T_rng>
  class rejection_sampling;

  template <typename T_rng>
  class cumulative_sampling;
}

/** Used to randomly draw objects based on their probability. Contains
    pointers to objects and their statistical weights. This class uses
    two policy classes: T_sampling_policy determines whether the
    distribution should be sampled using the cumulative probability
    distribution (cumulative_sampling) or using rejection sampling
    (rejection_sampling), and T_rng_policy is the random number
    generator type passed to sample(). The pointer vector and the
    weight array are parallel: weight i belongs to pointer i. This
    class never deletes the pointed-to objects. */
template <typename T, 
	  template <typename> class T_sampling_policy = mcrx::cumulative_sampling,
	  typename T_rng_policy = mcrx::global_random>
class mcrx::probability_distribution : public T_sampling_policy<T_rng_policy> {
private:
  /// Pointers to the objects, indexed like the weights.
  std::vector<T*> pointers_;

public:
  /// Creates an empty distribution; call setup() before sampling.
  probability_distribution() {}; 

  /** Creates a distribution over the objects p with weights w.
      p and w must have the same number of entries. */
  probability_distribution(const std::vector<T*>& p, const array_1& w) :
    T_sampling_policy<T_rng_policy>(w), pointers_(p) {
    // A length mismatch would let sample() index past the end of
    // pointers_ (or never reach some of them).
    assert(pointers_.size() == static_cast<std::size_t>(w.size()));
  };

  /** This function can be used to set up the object after
      construction. p and w must have the same number of entries. */
  void setup(const std::vector<T*>& p, const array_1& w) {
    assert(p.size() == static_cast<std::size_t>(w.size()));
    pointers_=p;
    this->set_weights(w);};

  /// Return a pointer to one of the objects randomly drawn according
  /// to the sampling weights, using the random number policy p.
  T* sample(T_rng_policy& p) const {
    const int i=this->draw_entry(p); 
    assert(i>=0);
    assert(static_cast<std::size_t>(i)<pointers_.size());
    return pointers_[i];
  };

  /// Return a const version of the pointer vector.
  const std::vector<T*>& pointers() const {return pointers_;};
};

/** Policy class for sampling objects based on rejection sampling,
    which doesn't require calculating the cumulative probability
    distribution. Rejection sampling is likely slower than cumulative
    sampling, especially for very nonuniform distributions, but it
    does not require a recalculation when weights are changed.

    Sampling is only unbiased if the stored maximum is an upper bound
    on every weight: an entry whose weight exceeds it is accepted
    every time it is proposed, so it is drawn less often than its
    weight warrants. */ 
template <typename T_rng_policy = mcrx::global_random>
class mcrx::rejection_sampling { 
private: 
  /// Statistical weights for the objects.  We use an array because
  /// it's easy to reference other objects without copying.
  array_1 weights_;
  /// Upper bound on the statistical weights; entry i is accepted with
  /// probability weights_(i)/max_.
  T_float max_;

protected:
  /// Destructor is protected to prevent clients from deleting this. 
  //~rejection_sampling () {}; 

public:
  /// Creates a sampler with no weights; set_weights() must be called
  /// before draw_entry(). max_ is zeroed so it is never read
  /// uninitialized.
  rejection_sampling () : max_(0) {};
  /// Creates a sampler over the weights w, normalized by their maximum.
  rejection_sampling (const array_1& w): weights_(w), max_(max (w)) {};

  /** Sets the statistical weights of the emitters when the maximum
      value is unknown.  This involves a search for the maximum value,
      so is slightly slower than if the max value is supplied. Note
      that this uses a reference to the supplied array, the data are
      not copied, so it's fast.  */
  void set_weights (const array_1& w) {set_weights (w, max (w));};
  /** Sets the statistical weights of the emitters when the maximum
      value (or a reasonably tight upper bound on it) is known. The
      value must not be smaller than any weight, or the sampling is
      biased (see the class description). Note that this uses a
      reference to the supplied array, the data are not copied, so
      it's fast.  */
  void set_weights (const array_1& w, T_float max) {
    weights_.reference(w); max_= max;};

  /** Returns the total statistical weight of all entries.  Note that
      because this function needs to sum up all the weights, it's
      slow.  */
  T_float total_weight () const {
    return sum(weights_);};

  /** Returns the weight of entry i. */
  T_float weight (int i) const {
    return weights_(i);};

  /** Draw an entry using the specified random number policy. Each
      attempt consumes two random numbers: one to propose an entry
      uniformly and one to accept or reject it. */
  int draw_entry (T_rng_policy& p) const {
    const int n = static_cast<int>(weights_.size());
    // With no weights, weights_(i) below would read out of bounds.
    assert(n > 0);
    int i;
    do {
      // Propose an entry uniformly. The min() guards against an RNG
      // that can return exactly 1.0, which would otherwise index one
      // past the end.
      i = std::min(static_cast<int>(floor(p.rnd()*n)), n - 1);
      // Accept with probability weights_(i)/max_, otherwise retry.
    } while (p.rnd () > weights_(i)/max_);
    assert(i>=0);
    assert(i<n);
    return i;
  };
};

/** Policy class for sampling objects based on searching the
    cumulative probability distribution. This is generally fast (it
    only requires a binary search), but it requires recalculating the
    cumulative distribution if any of the weights are changed.  */
template <typename T_rng_policy = mcrx::global_random>
class mcrx::cumulative_sampling {
private: 
  /// Cumulative (unnormalized) distribution: pdf_[i] is the sum of
  /// the weights of entries 0 through i. Empty when no weights are set.
  std::vector<T_float> pdf_;

protected:
  /// Destructor is protected to prevent clients from deleting this. 
  //~cumulative_sampling () {}; 

public:
  /// Creates a sampler with no weights; set_weights() must be called
  /// before draw_entry().
  cumulative_sampling () {};
  /// Creates a sampler over the weights w.
  cumulative_sampling (const array_1& w) {set_weights(w);};
  
  /** Sets the statistical weights of the emitters.  This involves
      summing up the cumulative distribution, so should not be changed
      very frequently. The weights are copied; w is not referenced
      afterwards. An empty w leaves the sampler empty.  */
  void set_weights (const array_1& w) {
    // int index, as in the original, so w(i) uses the plain integer
    // element accessor.
    const int n = static_cast<int>(w.size());
    pdf_.resize(n);
    if (n == 0)
      return;  // writing pdf_[0] would be out of bounds
    pdf_[0] = w(0);
    for (int i = 1; i < n; ++i)
      pdf_[i] = pdf_[i-1] + w(i);
  };

  /** Returns the total statistical weight of all emitters (zero if
      no weights are set).  */
  T_float total_weight () const {
    return pdf_.empty() ? T_float(0) : pdf_.back();};

  /** Returns the weight of entry i. */
  T_float weight (int i) const {
    return pdf_[i] - ( (i>0)?pdf_[i-1] : 0 );};

  /** Draw an entry using the specified random number policy. An entry
      with zero weight is never returned unless all weights are zero.  */
  int draw_entry (T_rng_policy& p) const {
    assert(!pdf_.empty());
    const T_float r=p.rnd()*total_weight();
    // upper_bound finds the first entry whose cumulative weight strictly
    // exceeds r. A zero-weight entry shares its cumulative value with
    // its predecessor, so it can never be chosen, even when r == 0
    // (lower_bound would pick a leading zero-weight entry in that case).
    std::vector<T_float>::const_iterator i =
      std::upper_bound (pdf_.begin(), pdf_.end(), r);
    // If r reached the total weight (RNG returned 1.0, or all weights
    // are zero) no entry exceeds it; take the first entry that attains
    // the total instead.
    if (i == pdf_.end())
      i = std::lower_bound (pdf_.begin(), pdf_.end(), r);
    const int ii = static_cast<int>(i-pdf_.begin());
    assert(ii>=0);
    assert(ii<static_cast<int>(pdf_.size()));
    return ii;
  };
};

#endif
